P4284 [SHOI2014]概率充电器


一道换根树形概率期望dp好题


题解:

先对问题进行转化

因为期望的线性性,所以整棵树期望几个被充电就等价于所有节点各自被充电的期望的和,又由于对于每个节点,期望都等于概率×1,所以其实题目就是在问你所有节点能被充上电的概率的和

方便起见,考虑从反面进行思考,即求每个节点的不被充上电的概率

记点$x$自己亮的概率为$q[x]$,x,y之间通电的概率为$e[x][y]$

设$g[x]$表示点x不亮的概率

容易得到$g[x]=(1-q[x])\prod\limits_{(x,y)\in E}(1-e[x][y]+e[x][y]g[y])$

方程的意思是(自己不会亮自己)且((与之相连的点不会导电)或(与之相连的点会导电但没电))

注意到这个方程是有后效性的,所以先确定一个全局根,再设$f[x]$表示点x不被以x为根子树内任何一个点点亮的概率,得到新方程$f[x]=(1-q[x])\prod\limits_{y \in son(x)}(1-e[x][y]+e[x][y]f[y])$

发现这个方程每在树上跑一遍,每次所选的全局根的g[x]都因为等于f[x]而得到

所以只要跑n遍dp,每次换个全局根就可以算出来所有g[x]了?

然而因为这个方程要求的东西显然是可以分成上贡献和下贡献两部分,且仅仅是乘起来的,所以我们考虑换根树形dp

在第二次dfs的时候跑g[x]的转移方程,这要求算出实际的g[fa],因为现实的g[fa]是重复计算了x这部分子树的贡献的,所以除去后带进方程,就可以得到g[x],而至于多余的部分,显然是可以用f[x]算出来的,也就是$g[fa]_{实际}=g[fa]/(1-e[fa][x]+e[fa][x]*f[x])$

算完所有的g[x]后,容易得出$ans=\sum\limits_{i=1}^{n}(1-g[i])$


代码:

#include <bits/stdc++.h>
using namespace std;
template<class t> inline t read(t &x){
    x=0;char c=getchar();bool f=0;
    while(!isdigit(c)) f|=c=='-',c=getchar();
    while(isdigit(c)) x=(x<<1)+(x<<3)+(c^48),c=getchar();
    if(f) x=-x;return  x;
}
template<class t> inline void write(t x){
    if(x<0){putchar('-'),write(-x);}
    else{if(x>9)write(x/10);putchar('0'+x%10);}
}

const int N=5e5+5;
int en,h[N],n;
double ans,g[N],f[N],p[N];

struct edge{
    int n,v;
    double w;
}e[N<<1];

void add(int x,int y,int z){
    e[++en]=(edge){h[x],y,0.01*z};
    h[x]=en;
}

void dfs1(int x,int fa){
    f[x]=1-p[x];
    for(int i=h[x];i;i=e[i].n){
        int y=e[i].v;
        if(y==fa) continue;
        dfs1(y,x);
        f[x]*=1-e[i].w+e[i].w*f[y];
    }
}

void dfs2(int x,int fa,int past){
    if(x==1) g[x]=f[x];
    else{
        double cfa=g[fa]/(1-e[past].w+e[past].w*f[x]);//contribution from ancestor:祖先的贡献
        g[x]=f[x]*(1-e[past].w+e[past].w*cfa);
    }
    for(int i=h[x];i;i=e[i].n){
        int y=e[i].v;
        if(y==fa) continue;
        dfs2(y,x,i);
    }
}

signed main(){
    read(n);
    for(int i=1,x,y,z;i<n;i++){
        read(x);read(y);read(z);
        add(x,y,z);
        add(y,x,z);
    }
    for(int i=1,x;i<=n;i++) p[i]=0.01*read(x);
    dfs1(1,0);
    dfs2(1,0,0);
    for(int i=1;i<=n;i++) ans+=1-g[i];
    printf("%.6lf",ans);
}

fighter